%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% This function computes the discrete state markov chain approximation for
% a continuous state VAR(1) process using Tauchen (1986) method.
% VAR is denoted as: Y(t)=Rho*Y(t-1)+E(t)
% where Y is a (Dim X 1) vector, Rho is a (Dim X Dim) matrix and
% E~N(0,Omega) and Omega is a (Dim X Dim) symmetric positive-definite matrix.
% Function: 
% [StateGrid,TrProbMat,InvDist,GridCell,IndGrid]=TauchenVector(Width,GridNum,Rho,Omega,TrProbEps)
% Input variables: 
%   Width:      Width of the grid (in the unit of ergodic dist. std.)
%               Vector  (Dim x 1)
%   GridNum:    Number of the discrete state space grid
%               Vector  (Dim x 1)
%   Rho:        VAR(1) coef. matrix of the process
%   Omega:      Covariance matrix of the VAR(1) process innovation term
%   TrProbEps:  (Optional) threshold for transition probability, if the
%               prob.<TrProbEps, it will be reassigned as 0 and the 
%               transition prob. vector will be re-normalized.
% Output variables:
%   StateGrid:      discrete state space grid
%                   Matrix (prod(GridNum) X Dim)
%                   Format: [Y1';Y2';...;YN'] where N=prod(GridNum)
%   TrProbMat:      transition prob. matrix
%                   Sparse Matrix (prod(GridNum) X prod(GridNum))
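%   InvDist:        invariant (stationary) distribution of the Markov chain
%                   Vector (prod(GridNum) X 1), non-negative, sums to one
%   GridCell:       Grid points on each separate dimension
%   IndGrid:        has a similar structure to StateGrid, but the exact
%                   state values are replaced by the state index on each
%                   separate dimension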
% Special Note: TrProbMat_{i,j}=Pr(Y1=Yi|Y0=Yj), i.e. each COLUMN of
%               TrProbMat is a conditional distribution and sums to one.
% Author: Xing Guo, xingguo@umich.edu
% Version: 9/9/2017
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [StateGrid,TrProbMat,InvDist,GridCell,IndGrid]=TauchenVector(Width,GridNum,Rho,Omega,TrProbEps)
% Tauchen (1986) discretization of the zero-mean VAR(1)
%   Y(t) = Rho*Y(t-1) + E(t),   E ~ N(0,Omega).
%
% Inputs
%   Width     : Dim-vector (row or column). Total width of the grid on each
%               dimension, measured in ergodic standard deviations of the
%               orthogonalized process Z = Lambda\Y, Lambda = chol(Omega,'lower').
%   GridNum   : Dim-vector (row or column). Number of grid points per dimension.
%   Rho       : (Dim x Dim) coefficient matrix; all eigenvalues must lie
%               strictly inside the unit circle.
%   Omega     : (Dim x Dim) innovation covariance matrix (passed to chol,
%               which errors if it is not positive definite).
%   TrProbEps : (optional, default 0) transition probabilities below this
%               threshold are set to 0 and the remaining ones re-normalized.
%
% Outputs
%   StateGrid : (N x Dim) state values, N = prod(GridNum); the first
%               dimension varies slowest, the last one fastest.
%   TrProbMat : sparse (N x N); TrProbMat(i,j) = Pr(next state i | current state j).
%   InvDist   : (N x 1) invariant distribution (non-negative, sums to one).
%   GridCell  : (Dim x 1) cell; sorted unique values of each StateGrid column.
%   IndGrid   : (N x Dim) per-dimension grid index of every state.
%
% Errors are raised for inconsistent dimensions and for a non-stationary Rho.
% The cell probabilities are evaluated with normcdf.

%% Preliminary checks
% Force column orientation: sigma below is a column vector, and mixing it
% with row-vector inputs would silently expand to (Dim x Dim) matrices.
Width       =   Width(:);
GridNum     =   GridNum(:);
Dim         =   numel(Width);
if numel(GridNum)~=Dim
    error('Check the dimension: GridNum');
end
if any([size(Rho),size(Omega)]~=Dim)
    error('Check the dimension: Rho and Omega');
end
if nargin<5
    TrProbEps       =   0;
end
% Without stationarity the ergodic covariance used to scale the grid does
% not exist, and the linear solve below would return meaningless values.
if max(abs(eig(Rho)))>=1
    error('Rho must have all eigenvalues strictly inside the unit circle (stationary VAR).');
end

%% Standardize the VAR to have an uncorrelated, unit-variance error term
Lambda      =   chol(Omega,'lower');
% Transformed VAR(1) coefficient matrix: Z(t) = At*Z(t-1) + e(t), e ~ N(0,I)
At          =   Lambda\Rho*Lambda;
% Ergodic covariance of Z solves vec(Sigma) = (I - kron(At,At)) \ vec(I)
Sigma       =   (eye(Dim^2)-kron(At,At))\reshape(eye(Dim),[Dim^2,1]);
Sigma       =   reshape(Sigma,[Dim,Dim]);
sigma       =   sqrt(diag(Sigma));

%% Construct the state grid (in the transformed coordinates)
% Step length and symmetric boundaries on each dimension
w           =   sigma.*Width./(GridNum-1);
LowerBound  =   -Width/2.*sigma;
UpperBound  =   -LowerBound;
% Grid points and the edges of the cell around each point, per dimension
AxisGrid    =   cell(Dim,1);
AxisUpp     =   cell(Dim,1);
AxisLow     =   cell(Dim,1);
for i=1:Dim
    AxisGrid{i}     =   linspace(LowerBound(i),UpperBound(i),GridNum(i))';
    AxisUpp{i}      =   AxisGrid{i}+w(i)/2;
    AxisLow{i}      =   AxisGrid{i}-w(i)/2;
    % The outermost cells absorb the tails
    AxisUpp{i}(end) =   Inf;
    AxisLow{i}(1)   =   -Inf;
end

% Cartesian product of the per-dimension grids. On dimension i every index
% is repeated RepeatNum times in a row and that pattern is tiled TileNum
% times, so dimension 1 varies slowest and dimension Dim fastest.
N           =   prod(GridNum);
Grid        =   zeros(Dim,N);
IndGrid     =   zeros(Dim,N);
GridUpp     =   zeros(Dim,N);
GridLow     =   zeros(Dim,N);
RepeatNum   =   N;
for i=1:Dim
    RepeatNum       =   RepeatNum/GridNum(i);
    TileNum         =   N/(RepeatNum*GridNum(i));
    idx             =   repmat(kron(1:GridNum(i),ones(1,RepeatNum)),[1,TileNum]);
    IndGrid(i,:)    =   idx;
    Grid(i,:)       =   reshape(AxisGrid{i}(idx),[1,N]);
    GridUpp(i,:)    =   reshape(AxisUpp{i}(idx),[1,N]);
    GridLow(i,:)    =   reshape(AxisLow{i}(idx),[1,N]);
end

%% Construct the transition probability matrix
% Conditional mean of next period's Z for every current state
AtGrid      =   At*Grid;

% Column j holds the distribution over next states given current state j
TrProbMat   =   zeros(N,N);
for j=1:N
    % The innovations of Z are independent N(0,1), so the probability of a
    % cell is the product of the one-dimensional interval probabilities.
    prob    =   prod( normcdf(bsxfun(@minus,GridUpp,AtGrid(:,j))) ...
                     -normcdf(bsxfun(@minus,GridLow,AtGrid(:,j))),1 );
    if TrProbEps>0
        kept                    =   prob;
        kept(kept<TrProbEps)    =   0;
        total                   =   sum(kept,2);
        % If the threshold would wipe out the whole distribution, keep the
        % unthresholded probabilities instead of producing 0/0 = NaN.
        if total>0
            prob    =   kept/total;
        end
    end
    TrProbMat(:,j)  =   prob(:);
end
TrProbMat   =   sparse(TrProbMat);

%% Map back to the original coordinates and compute the invariant distribution
StateGrid   =   (Lambda*Grid)';
GridCell    =   cell(Dim,1);
for i=1:Dim
    % unique already returns its output sorted in ascending order
    GridCell{i}     =   unique(StateGrid(:,i));
end
IndGrid     =   IndGrid';
% Eigenvector for the eigenvalue closest to 1 (the shift is slightly off 1
% so that the shifted matrix is not singular)
[InvDist,~] =   eigs(TrProbMat,1,1+1e-6);
% The eigenvector is only determined up to scale and sign
if sum(InvDist)<0
    InvDist     =   -InvDist;
end
InvDist(InvDist<0)      =   0;
InvDist         =   InvDist/sum(InvDist);